1use super::*;
2use crate::linalg::temp_mat_uninit;
3use crate::linalg::zip::Diag;
4use crate::mat::{AsMatMut, MatMut, MatRef};
5use crate::utils::thread::join_raw;
6use crate::{assert, debug_assert, unzip, zip};
7
/// How the diagonal of a triangular operand is interpreted by the kernels in this file.
#[repr(u8)]
#[derive(Copy, Clone, Debug)]
pub(crate) enum DiagonalKind {
    /// The diagonal is treated as all zeros (strictly triangular operand); stored values are ignored.
    Zero,
    /// The diagonal is treated as all ones (unit triangular operand); stored values are ignored.
    Unit,
    /// The diagonal is read from the matrix storage.
    Generic,
}
15
16#[inline]
17fn pointer_offset<T>(ptr: *const T) -> usize {
18 if try_const! {core::mem::size_of::<T>().is_power_of_two() &&core::mem::size_of::<T>() <= 64 } {
19 ptr.align_offset(64).wrapping_neg() % 16
20 } else {
21 0
22 }
23}
24
25#[faer_macros::math]
26fn copy_lower<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, src_diag: DiagonalKind) {
27 let N = dst.nrows();
28 let mut dst = dst;
29 match src_diag {
30 DiagonalKind::Zero => {
31 dst.copy_from_strict_triangular_lower(src);
32 for j in N.indices() {
33 let zero = zero();
34 dst[(j, j)] = zero;
35 }
36 },
37 DiagonalKind::Unit => {
38 dst.copy_from_strict_triangular_lower(src);
39 for j in N.indices() {
40 let one = one();
41 dst[(j, j)] = one;
42 }
43 },
44 DiagonalKind::Generic => dst.copy_from_triangular_lower(src),
45 }
46
47 zip!(dst).for_each_triangular_upper(Diag::Skip, |unzip!(dst)| *dst = zero());
48}
49
50#[faer_macros::math]
51fn accum_lower<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, skip_diag: bool, beta: Accum) {
52 let N = dst.nrows();
53 debug_assert!(N == dst.nrows());
54 debug_assert!(N == dst.ncols());
55 debug_assert!(N == src.nrows());
56 debug_assert!(N == src.ncols());
57
58 match beta {
59 Accum::Add => {
60 zip!(dst, src).for_each_triangular_lower(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(dst, src)| *dst = *dst + *src);
61 },
62 Accum::Replace => {
63 zip!(dst, src).for_each_triangular_lower(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(dst, src)| *dst = copy(*src));
64 },
65 }
66}
67
68#[faer_macros::math]
69fn copy_upper<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, src_diag: DiagonalKind) {
70 copy_lower(dst.transpose_mut(), src.transpose(), src_diag)
71}
72
/// 64-byte aligned backing array of `32 * 16` elements.
///
/// Used (wrapped in `MaybeUninit`) by `stack_mat_16x16!` as scratch space for the small
/// temporaries of the base cases.
#[repr(align(64))]
struct Storage<T>([T; 32 * 16]);
75
/// Declares `$name`, a mutable `$n × $n` matrix view over uninitialized stack storage.
///
/// * `$n` - dimension of the temporary; every call site in this file guards it with `n <= 16`.
/// * `$offset` - first row of the 32-row scratch matrix used by the view (call sites pass
///   `pointer_offset(..)`, which is below 16).
/// * `$rs`, `$cs` - row/column stride of the matrix the temporary stands in for; they only
///   select the orientation of the view.
/// * `$T` - element type.
///
/// The contents are uninitialized: callers must fully write the view (e.g. through
/// `copy_lower` or a `Accum::Replace` product) before reading it.
///
/// NOTE(review): the stride-dependent flips presumably make the temporary's memory layout
/// resemble the source's; they do not change the logical shape.
macro_rules! stack_mat_16x16 {
    ($name: ident, $n: expr, $offset: expr, $rs: expr, $cs: expr, $T: ty $(,)?) => {
        // raw, 64-byte aligned scratch space that lives for the rest of the enclosing scope
        let mut __tmp = core::mem::MaybeUninit::<Storage<$T>>::uninit();
        let __stack = MemStack::new_any(core::slice::from_mut(&mut __tmp));
        // 32 rows × `$n` columns; only rows `$offset .. $offset + $n` are exposed
        let mut $name = unsafe { temp_mat_uninit(32, $n, __stack) }.0;
        let mut $name = $name.as_mat_mut().subrows_mut($offset, $n);
        if $cs.unsigned_abs() == 1 {
            // unit column stride in the source
            $name = $name.transpose_mut();
            if $cs == 1 {
                $name = $name.transpose_mut().reverse_cols_mut();
            }
        } else if $rs == -1 {
            // reversed rows in the source
            $name = $name.reverse_rows_mut();
        }
    };
}
92
/// Dense-by-lower-triangular product written into a dense destination.
///
/// `rhs` is interpreted as lower triangular: its strictly upper part is never read and its
/// diagonal is interpreted according to `rhs_diag`. The product is delegated to
/// `super::matmul_with_conj` with `beta`, `alpha` and the conjugation flags forwarded
/// (presumably `dst = [dst +] alpha * lhs * rhs`; see that function for the exact contract).
///
/// Dimensions are only checked with `debug_assert!`; callers must pass consistent shapes.
#[faer_macros::math]
fn mat_x_lower_impl_unchecked<'M, 'N, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
    beta: Accum,
    lhs: MatRef<'_, T, Dim<'M>, Dim<'N>>,
    rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs_diag: DiagonalKind,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    let N = rhs.nrows();
    let M = lhs.nrows();
    let n = N.unbound();
    let m = M.unbound();
    debug_assert!(M == lhs.nrows());
    debug_assert!(N == lhs.ncols());
    debug_assert!(N == rhs.nrows());
    debug_assert!(N == rhs.ncols());
    debug_assert!(M == dst.nrows());
    debug_assert!(N == dst.ncols());

    // below this work threshold the two recursive halves are run sequentially
    let join_parallelism = if n * n * m < 128usize * 128usize * 64usize { Par::Seq } else { par };

    if n <= 16 {
        // base case: materialize the triangle as a dense matrix (zeros above the diagonal)
        // on the stack and use the dense kernel
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

                copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                let mut dst = dst;
                super::matmul_with_conj(dst.rb_mut(), beta, lhs, conj_lhs, temp_rhs.rb(), conj_rhs, alpha.clone(), par);
            }
        };
        op();
    } else {
        // split rhs into 3 used sections (the top right one is never read),
        // split lhs and dst into left/right column halves
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);
        let (lhs_left, lhs_right) = lhs.split_cols_with(bs);
        let (mut dst_left, mut dst_right) = dst.split_cols_with_mut(bs);

        {
            // lhs_left  × rhs_top_left  => dst_left  (mat × low)
            // lhs_right × rhs_bot_right => dst_right (mat × low)
            join_raw(
                |par| mat_x_lower_impl_unchecked(dst_left.rb_mut(), beta, lhs_left, rhs_top_left, rhs_diag, alpha, conj_lhs, conj_rhs, par),
                |par| {
                    mat_x_lower_impl_unchecked(
                        dst_right.rb_mut(),
                        beta,
                        lhs_right,
                        rhs_bot_right,
                        rhs_diag,
                        alpha,
                        conj_lhs,
                        conj_rhs,
                        par,
                    )
                },
                join_parallelism,
            )
        };

        // lhs_right × rhs_bot_left is dense × dense and is added on top of dst_left,
        // which must already hold the `beta` pass from above
        super::matmul_with_conj(dst_left, Accum::Add, lhs_right, conj_lhs, rhs_bot_left, conj_rhs, alpha.clone(), par);
    }
}
166
/// Lower-by-lower triangular product written into the lower triangle of `dst`.
///
/// Only the lower triangles of `lhs` and `rhs` are read, with diagonals interpreted according
/// to `lhs_diag` / `rhs_diag`. Only the lower triangle of `dst` is written; its diagonal is
/// left untouched when `skip_diag` is `true`. `beta`, `alpha` and the conjugation flags are
/// forwarded to the underlying kernels.
///
/// Dimensions are only checked with `debug_assert!`.
#[faer_macros::math]
fn lower_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    beta: Accum,
    skip_diag: bool,
    lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    lhs_diag: DiagonalKind,
    rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs_diag: DiagonalKind,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    let N = dst.nrows();
    let n = N.unbound();
    debug_assert!(N == lhs.nrows());
    debug_assert!(N == lhs.ncols());
    debug_assert!(N == rhs.nrows());
    debug_assert!(N == rhs.ncols());
    debug_assert!(N == dst.nrows());
    debug_assert!(N == dst.ncols());

    if n <= 16 {
        // base case: dense product of the materialized triangles into a stack temporary,
        // then only its lower triangle is merged into `dst`
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
                stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
                stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

                copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
                copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                super::matmul_with_conj(
                    temp_dst.rb_mut(),
                    Accum::Replace,
                    temp_lhs.rb(),
                    conj_lhs,
                    temp_rhs.rb(),
                    conj_rhs,
                    alpha.clone(),
                    par,
                );
                accum_lower(dst, temp_dst.rb(), skip_diag, beta);
            }
        };
        op();
    } else {
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
        let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_with(bs, bs);
        let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

        // lhs_top_left  × rhs_top_left  => dst_top_left  | low × low => low
        // lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat
        // lhs_bot_right × rhs_bot_left  => dst_bot_left  | low × mat => mat (added)
        // lhs_bot_right × rhs_bot_right => dst_bot_right | low × low => low

        lower_x_lower_into_lower_impl_unchecked(
            dst_top_left,
            beta,
            skip_diag,
            lhs_top_left,
            lhs_diag,
            rhs_top_left,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );
        mat_x_lower_impl_unchecked(
            dst_bot_left.rb_mut(),
            beta,
            lhs_bot_left,
            rhs_top_left,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );
        // `low × mat` is turned into `mat × low` by transposing the product and reversing
        // rows and columns (which maps the transposed, upper factor back to a lower one);
        // operands and their conjugation flags swap roles accordingly
        mat_x_lower_impl_unchecked(
            dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(),
            Accum::Add,
            rhs_bot_left.reverse_rows_and_cols().transpose(),
            lhs_bot_right.reverse_rows_and_cols().transpose(),
            lhs_diag,
            alpha,
            conj_rhs,
            conj_lhs,
            par,
        );
        lower_x_lower_into_lower_impl_unchecked(
            dst_bot_right,
            beta,
            skip_diag,
            lhs_bot_right,
            lhs_diag,
            rhs_bot_right,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        )
    }
}
279
/// Upper-by-lower triangular product written into a dense destination.
///
/// Only the upper triangle of `lhs` and the lower triangle of `rhs` are read, with diagonals
/// interpreted according to `lhs_diag` / `rhs_diag`. All of `dst` is written. `beta`, `alpha`
/// and the conjugation flags are forwarded to the underlying kernels.
///
/// Dimensions are only checked with `debug_assert!`.
#[math]
fn upper_x_lower_impl_unchecked<'N, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    beta: Accum,
    lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    lhs_diag: DiagonalKind,
    rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs_diag: DiagonalKind,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    let N = dst.nrows();
    let n = N.unbound();
    debug_assert!(N == lhs.nrows());
    debug_assert!(N == lhs.ncols());
    debug_assert!(N == rhs.nrows());
    debug_assert!(N == rhs.ncols());
    debug_assert!(N == dst.nrows());
    debug_assert!(N == dst.ncols());

    if n <= 16 {
        // base case: materialize both triangles as dense matrices on the stack
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
                stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

                copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
                copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                super::matmul_with_conj(dst, beta, temp_lhs.rb(), conj_lhs, temp_rhs.rb(), conj_rhs, alpha.clone(), par);
            }
        };
        op();
    } else {
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
        let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_with(bs, bs);
        let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

        // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => mat
        // lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => mat (added)
        //
        // lhs_top_right × rhs_bot_right => dst_top_right | mat × low => mat
        // lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat
        // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => mat

        join_raw(
            |par| {
                // the `beta` pass must run before the `Accum::Add` pass on the same block
                super::matmul_with_conj(
                    dst_top_left.rb_mut(),
                    beta,
                    lhs_top_right,
                    conj_lhs,
                    rhs_bot_left,
                    conj_rhs,
                    alpha.clone(),
                    par,
                );
                upper_x_lower_impl_unchecked(
                    dst_top_left,
                    Accum::Add,
                    lhs_top_left,
                    lhs_diag,
                    rhs_top_left,
                    rhs_diag,
                    alpha,
                    conj_lhs,
                    conj_rhs,
                    par,
                )
            },
            |par| {
                join_raw(
                    |par| {
                        mat_x_lower_impl_unchecked(
                            dst_top_right,
                            beta,
                            lhs_top_right,
                            rhs_bot_right,
                            rhs_diag,
                            alpha,
                            conj_lhs,
                            conj_rhs,
                            par,
                        )
                    },
                    |par| {
                        // `upp × mat` is computed as the transposed product `mat × low`;
                        // operands and conjugation flags swap roles
                        mat_x_lower_impl_unchecked(
                            dst_bot_left.transpose_mut(),
                            beta,
                            rhs_bot_left.transpose(),
                            lhs_bot_right.transpose(),
                            lhs_diag,
                            alpha,
                            conj_rhs,
                            conj_lhs,
                            par,
                        )
                    },
                    par,
                );

                upper_x_lower_impl_unchecked(
                    dst_bot_right,
                    beta,
                    lhs_bot_right,
                    lhs_diag,
                    rhs_bot_right,
                    rhs_diag,
                    alpha,
                    conj_lhs,
                    conj_rhs,
                    par,
                )
            },
            par,
        );
    }
}
405
/// Upper-by-lower triangular product written into the lower triangle of `dst`.
///
/// Only the upper triangle of `lhs` and the lower triangle of `rhs` are read, with diagonals
/// interpreted according to `lhs_diag` / `rhs_diag`. Only the lower triangle of `dst` is
/// written; its diagonal is left untouched when `skip_diag` is `true`.
///
/// Dimensions are only checked with `debug_assert!`.
#[math]
fn upper_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    beta: Accum,
    skip_diag: bool,
    lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    lhs_diag: DiagonalKind,
    rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs_diag: DiagonalKind,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    let N = dst.nrows();
    let n = N.unbound();
    debug_assert!(N == lhs.nrows());
    debug_assert!(N == lhs.ncols());
    debug_assert!(N == rhs.nrows());
    debug_assert!(N == rhs.ncols());
    debug_assert!(N == dst.nrows());
    debug_assert!(N == dst.ncols());

    if n <= 16 {
        // base case: dense product of the materialized triangles into a stack temporary,
        // then only its lower triangle is merged into `dst`
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
                stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
                stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

                copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
                copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

                super::matmul_with_conj(
                    temp_dst.rb_mut(),
                    Accum::Replace,
                    temp_lhs.rb(),
                    conj_lhs,
                    temp_rhs.rb(),
                    conj_rhs,
                    alpha.clone(),
                    par,
                );

                accum_lower(dst, temp_dst.rb(), skip_diag, beta);
            }
        };
        op();
    } else {
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
        let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_with(bs, bs);
        let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

        // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low
        // lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => low (added)
        //
        // lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat
        // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => low

        join_raw(
            |par| {
                // the `beta` pass must run before the `Accum::Add` pass on the same block
                mat_x_mat_into_lower_impl_unchecked(
                    dst_top_left.rb_mut(),
                    beta,
                    skip_diag,
                    lhs_top_right,
                    rhs_bot_left,
                    alpha,
                    conj_lhs,
                    conj_rhs,
                    par,
                );
                upper_x_lower_into_lower_impl_unchecked(
                    dst_top_left,
                    Accum::Add,
                    skip_diag,
                    lhs_top_left,
                    lhs_diag,
                    rhs_top_left,
                    rhs_diag,
                    alpha,
                    conj_lhs,
                    conj_rhs,
                    par,
                )
            },
            |par| {
                // `upp × mat` is computed as the transposed product `mat × low`;
                // operands and conjugation flags swap roles
                mat_x_lower_impl_unchecked(
                    dst_bot_left.transpose_mut(),
                    beta,
                    rhs_bot_left.transpose(),
                    lhs_bot_right.transpose(),
                    lhs_diag,
                    alpha,
                    conj_rhs,
                    conj_lhs,
                    par,
                );
                upper_x_lower_into_lower_impl_unchecked(
                    dst_bot_right,
                    beta,
                    skip_diag,
                    lhs_bot_right,
                    lhs_diag,
                    rhs_bot_right,
                    rhs_diag,
                    alpha,
                    conj_lhs,
                    conj_rhs,
                    par,
                )
            },
            par,
        );
    }
}
527
/// Dense-by-dense product written into the lower triangle of the square matrix `dst`.
///
/// Both operands are read in full. Only the lower triangle of `dst` is written; its diagonal
/// is left untouched when `skip_diag` is `true`.
///
/// On `x86_64` with the `std` feature, native `f32`/`f64`/`c32`/`c64` element types are
/// dispatched to `private_gemm_x86::gemm` with `DstKind::Lower` when AVX-512F or AVX2+FMA is
/// detected at runtime and no dimension is zero. Everything else uses the recursive fallback.
///
/// Dimensions are only checked with `debug_assert!` (and only on the fallback path).
#[math]
fn mat_x_mat_into_lower_impl_unchecked<'N, 'K, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    beta: Accum,
    skip_diag: bool,
    lhs: MatRef<'_, T, Dim<'N>, Dim<'K>>,
    rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    if const { T::IS_NATIVE_F64 || T::IS_NATIVE_F32 || T::IS_NATIVE_C64 || T::IS_NATIVE_C32 } {
        use private_gemm_x86::*;

        // pick the widest supported instruction set at runtime
        let feat = if std::arch::is_x86_feature_detected!("avx512f") {
            Some(InstrSet::Avx512)
        } else if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
            Some(InstrSet::Avx256)
        } else {
            None
        };

        if *dst.nrows() > 0 && *dst.ncols() > 0 && *lhs.ncols() > 0 {
            if let Some(feat) = feat {
                // NOTE(review): soundness relies on the pointers, strides and dimensions
                // passed below describing valid views, and on the `gemm` argument order;
                // confirm against the `private_gemm_x86::gemm` signature
                unsafe {
                    // dropping the first row of `dst` and `lhs` shifts the diagonal out of
                    // the lower triangle, so the strict lower part is what gets written
                    let (dst, lhs) = if skip_diag {
                        (dst.as_dyn_mut().get_mut(1.., ..), lhs.as_dyn().get(1.., ..))
                    } else {
                        (dst.as_dyn_mut(), lhs.as_dyn())
                    };

                    private_gemm_x86::gemm(
                        const {
                            if T::IS_NATIVE_F64 {
                                DType::F64
                            } else if T::IS_NATIVE_F32 {
                                DType::F32
                            } else if T::IS_NATIVE_C64 {
                                DType::C64
                            } else {
                                DType::C32
                            }
                        },
                        const { IType::U32 },
                        feat,
                        dst.nrows(),
                        dst.ncols(),
                        lhs.ncols(),
                        dst.as_ptr_mut() as _,
                        dst.row_stride(),
                        dst.col_stride(),
                        core::ptr::null(),
                        core::ptr::null(),
                        DstKind::Lower,
                        match beta {
                            crate::Accum::Replace => Accum::Replace,
                            crate::Accum::Add => Accum::Add,
                        },
                        lhs.as_ptr() as _,
                        lhs.row_stride(),
                        lhs.col_stride(),
                        conj_lhs == Conj::Yes,
                        core::ptr::null(),
                        0,
                        rhs.as_ptr() as _,
                        rhs.row_stride(),
                        rhs.col_stride(),
                        conj_rhs == Conj::Yes,
                        alpha as *const T as *const (),
                        par.degree(),
                    );

                    return;
                }
            }
        }
    }

    let N = dst.nrows();
    let K = lhs.ncols();
    let n = N.unbound();
    let k = K.unbound();
    debug_assert!(dst.nrows() == dst.ncols());
    debug_assert!(dst.nrows() == lhs.nrows());
    debug_assert!(dst.ncols() == rhs.ncols());
    debug_assert!(lhs.ncols() == rhs.nrows());

    // below this work threshold everything (including nested calls) runs sequentially
    let par = if n * n * k < 128usize * 128usize * 128usize { Par::Seq } else { par };

    if n <= 16 {
        // base case: dense product into a stack temporary, then only its lower triangle
        // is merged into `dst`
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);

                super::matmul_with_conj(temp_dst.rb_mut(), Accum::Replace, lhs, conj_lhs, rhs, conj_rhs, alpha.clone(), par);
                accum_lower(dst, temp_dst.rb(), skip_diag, beta);
            }
        };
        op();
    } else {
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
        let (lhs_top, lhs_bot) = lhs.split_rows_with(bs);
        let (rhs_left, rhs_right) = rhs.split_cols_with(bs);

        // lhs_bot × rhs_left  => dst_bot_left  | mat × mat => mat
        // lhs_top × rhs_left  => dst_top_left  | mat × mat => low
        // lhs_bot × rhs_right => dst_bot_right | mat × mat => low
        join_raw(
            |par| super::matmul_with_conj(dst_bot_left, beta, lhs_bot, conj_lhs, rhs_left, conj_rhs, alpha.clone(), par),
            |par| {
                join_raw(
                    |par| mat_x_mat_into_lower_impl_unchecked(dst_top_left, beta, skip_diag, lhs_top, rhs_left, alpha, conj_lhs, conj_rhs, par),
                    |par| mat_x_mat_into_lower_impl_unchecked(dst_bot_right, beta, skip_diag, lhs_bot, rhs_right, alpha, conj_lhs, conj_rhs, par),
                    par,
                )
            },
            par,
        );
    }
}
652
/// Dense-by-lower-triangular product written into the lower triangle of `dst`.
///
/// `lhs` is read in full; only the lower triangle of `rhs` is read, with its diagonal
/// interpreted according to `rhs_diag`. Only the lower triangle of `dst` is written; its
/// diagonal is left untouched when `skip_diag` is `true`.
///
/// Dimensions are only checked with `debug_assert!`.
#[math]
fn mat_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
    dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
    beta: Accum,
    skip_diag: bool,
    lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
    rhs_diag: DiagonalKind,
    alpha: &T,
    conj_lhs: Conj,
    conj_rhs: Conj,
    par: Par,
) {
    let N = dst.nrows();
    let n = N.unbound();
    debug_assert!(N == dst.nrows());
    debug_assert!(N == dst.ncols());
    debug_assert!(N == lhs.nrows());
    debug_assert!(N == lhs.ncols());
    debug_assert!(N == rhs.nrows());
    debug_assert!(N == rhs.ncols());

    if n <= 16 {
        // base case: materialize the triangle of `rhs`, multiply into a stack temporary,
        // then merge only the lower triangle into `dst`
        let op = {
            #[inline(never)]
            || {
                stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
                stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

                copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
                super::matmul_with_conj(
                    temp_dst.rb_mut(),
                    Accum::Replace,
                    lhs,
                    conj_lhs,
                    temp_rhs.rb(),
                    conj_rhs,
                    alpha.clone(),
                    par,
                );
                accum_lower(dst, temp_dst.rb(), skip_diag, beta);
            }
        };
        op();
    } else {
        make_guard!(HEAD);
        make_guard!(TAIL);
        let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

        let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
        let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_with(bs, bs);
        let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

        // lhs_bot_right × rhs_bot_left  => dst_bot_left  | mat × mat => mat
        // lhs_bot_right × rhs_bot_right => dst_bot_right | mat × low => low
        //
        // lhs_top_left  × rhs_top_left  => dst_top_left  | mat × low => low
        // lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low (added)
        // lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat (added)

        super::matmul_with_conj(
            dst_bot_left.rb_mut(),
            beta,
            lhs_bot_right,
            conj_lhs,
            rhs_bot_left,
            conj_rhs,
            alpha.clone(),
            par,
        );
        mat_x_lower_into_lower_impl_unchecked(
            dst_bot_right,
            beta,
            skip_diag,
            lhs_bot_right,
            rhs_bot_right,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );

        mat_x_lower_into_lower_impl_unchecked(
            dst_top_left.rb_mut(),
            beta,
            skip_diag,
            lhs_top_left,
            rhs_top_left,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );
        // the two `Accum::Add` passes below rely on the `beta` passes above having run first
        mat_x_mat_into_lower_impl_unchecked(
            dst_top_left,
            Accum::Add,
            skip_diag,
            lhs_top_right,
            rhs_bot_left,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );
        mat_x_lower_impl_unchecked(
            dst_bot_left,
            Accum::Add,
            lhs_bot_left,
            rhs_top_left,
            rhs_diag,
            alpha,
            conj_lhs,
            conj_rhs,
            par,
        );
    }
}
772
/// Describes which part of a matrix participates in a triangular matrix product.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockStructure {
    /// The full (dense) matrix is used.
    Rectangular,
    /// Only the lower triangle, including the stored diagonal, is used.
    TriangularLower,
    /// Only the strictly lower triangle is used; the diagonal is treated as zero.
    StrictTriangularLower,
    /// Only the strictly lower triangle is used; the diagonal is treated as one.
    UnitTriangularLower,
    /// Only the upper triangle, including the stored diagonal, is used.
    TriangularUpper,
    /// Only the strictly upper triangle is used; the diagonal is treated as zero.
    StrictTriangularUpper,
    /// Only the strictly upper triangle is used; the diagonal is treated as one.
    UnitTriangularUpper,
}
793
794impl BlockStructure {
795 #[inline]
797 pub fn is_dense(self) -> bool {
798 matches!(self, BlockStructure::Rectangular)
799 }
800
801 #[inline]
803 pub fn is_lower(self) -> bool {
804 use BlockStructure::*;
805 matches!(self, TriangularLower | StrictTriangularLower | UnitTriangularLower)
806 }
807
808 #[inline]
810 pub fn is_upper(self) -> bool {
811 use BlockStructure::*;
812 matches!(self, TriangularUpper | StrictTriangularUpper | UnitTriangularUpper)
813 }
814
815 #[inline]
817 pub fn transpose(self) -> Self {
818 use BlockStructure::*;
819 match self {
820 Rectangular => Rectangular,
821 TriangularLower => TriangularUpper,
822 StrictTriangularLower => StrictTriangularUpper,
823 UnitTriangularLower => UnitTriangularUpper,
824 TriangularUpper => TriangularLower,
825 StrictTriangularUpper => StrictTriangularLower,
826 UnitTriangularUpper => UnitTriangularLower,
827 }
828 }
829
830 #[inline]
831 pub(crate) fn diag_kind(self) -> DiagonalKind {
832 use BlockStructure::*;
833 match self {
834 Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic,
835 StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero,
836 UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit,
837 }
838 }
839}
840
/// Validates the shapes of a triangular matrix product.
///
/// # Panics
///
/// Panics (reported at the caller's location, via `#[track_caller]`) if
/// * `dst_nrows != lhs_nrows`, `dst_ncols != rhs_ncols` or `lhs_ncols != rhs_nrows`, or
/// * any of `dst`, `lhs`, `rhs` has a non-dense structure but is not square.
#[track_caller]
fn precondition<M: Shape, N: Shape, K: Shape>(
    dst_nrows: M,
    dst_ncols: N,
    dst_structure: BlockStructure,
    lhs_nrows: M,
    lhs_ncols: K,
    lhs_structure: BlockStructure,
    rhs_nrows: K,
    rhs_ncols: N,
    rhs_structure: BlockStructure,
) {
    // the three dimension pairs of `dst = lhs * rhs` must agree
    assert!(all(dst_nrows == lhs_nrows, dst_ncols == rhs_ncols, lhs_ncols == rhs_nrows,));

    // compare as plain sizes from here on (row and column shape types may differ)
    let dst_nrows = dst_nrows.unbound();
    let dst_ncols = dst_ncols.unbound();
    let lhs_nrows = lhs_nrows.unbound();
    let lhs_ncols = lhs_ncols.unbound();
    let rhs_nrows = rhs_nrows.unbound();
    let rhs_ncols = rhs_ncols.unbound();

    // triangular blocks only make sense for square matrices
    if !dst_structure.is_dense() {
        assert!(dst_nrows == dst_ncols);
    }
    if !lhs_structure.is_dense() {
        assert!(lhs_nrows == lhs_ncols);
    }
    if !rhs_structure.is_dense() {
        assert!(rhs_nrows == rhs_ncols);
    }
}
872
873#[track_caller]
940#[inline]
941pub fn matmul_with_conj<T: ComplexField, M: Shape, N: Shape, K: Shape>(
942 dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
943 dst_structure: BlockStructure,
944 beta: Accum,
945 lhs: impl AsMatRef<T = T, Rows = M, Cols = K>,
946 lhs_structure: BlockStructure,
947 conj_lhs: Conj,
948 rhs: impl AsMatRef<T = T, Rows = K, Cols = N>,
949 rhs_structure: BlockStructure,
950 conj_rhs: Conj,
951 alpha: T,
952 par: Par,
953) {
954 let mut dst = dst;
955 let dst = dst.as_mat_mut();
956 let lhs = lhs.as_mat_ref();
957 let rhs = rhs.as_mat_ref();
958
959 precondition(
960 dst.nrows(),
961 dst.ncols(),
962 dst_structure,
963 lhs.nrows(),
964 lhs.ncols(),
965 lhs_structure,
966 rhs.nrows(),
967 rhs.ncols(),
968 rhs_structure,
969 );
970
971 make_guard!(M);
972 make_guard!(N);
973 make_guard!(K);
974 let M = dst.nrows().bind(M);
975 let N = dst.ncols().bind(N);
976 let K = lhs.ncols().bind(K);
977
978 matmul_imp(
979 dst.as_dyn_stride_mut().as_shape_mut(M, N),
980 dst_structure,
981 beta,
982 lhs.as_dyn_stride().canonical().as_shape(M, K),
983 lhs_structure,
984 conj_lhs,
985 rhs.as_dyn_stride().canonical().as_shape(K, N),
986 rhs_structure,
987 conj_rhs,
988 &alpha,
989 par,
990 );
991}
992
993#[track_caller]
1058#[inline]
1059pub fn matmul<T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>, M: Shape, N: Shape, K: Shape>(
1060 dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
1061 dst_structure: BlockStructure,
1062 beta: Accum,
1063 lhs: impl AsMatRef<T = LhsT, Rows = M, Cols = K>,
1064 lhs_structure: BlockStructure,
1065 rhs: impl AsMatRef<T = RhsT, Rows = K, Cols = N>,
1066 rhs_structure: BlockStructure,
1067 alpha: T,
1068 par: Par,
1069) {
1070 let mut dst = dst;
1071 let dst = dst.as_mat_mut();
1072 let lhs = lhs.as_mat_ref();
1073 let rhs = rhs.as_mat_ref();
1074
1075 precondition(
1076 dst.nrows(),
1077 dst.ncols(),
1078 dst_structure,
1079 lhs.nrows(),
1080 lhs.ncols(),
1081 lhs_structure,
1082 rhs.nrows(),
1083 rhs.ncols(),
1084 rhs_structure,
1085 );
1086
1087 make_guard!(M);
1088 make_guard!(N);
1089 make_guard!(K);
1090 let M = dst.nrows().bind(M);
1091 let N = dst.ncols().bind(N);
1092 let K = lhs.ncols().bind(K);
1093
1094 matmul_imp(
1095 dst.as_dyn_stride_mut().as_shape_mut(M, N),
1096 dst_structure,
1097 beta,
1098 lhs.as_dyn_stride().canonical().as_shape(M, K),
1099 lhs_structure,
1100 try_const! { Conj::get::<LhsT>() },
1101 rhs.as_dyn_stride().canonical().as_shape(K, N),
1102 rhs_structure,
1103 try_const! { Conj::get::<RhsT>() },
1104 alpha.by_ref(),
1105 par,
1106 );
1107}
1108
1109#[math]
1110fn matmul_imp<'M, 'N, 'K, T: ComplexField>(
1111 dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
1112 dst_structure: BlockStructure,
1113 beta: Accum,
1114 lhs: MatRef<'_, T, Dim<'M>, Dim<'K>>,
1115 lhs_structure: BlockStructure,
1116 conj_lhs: Conj,
1117 rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
1118 rhs_structure: BlockStructure,
1119 conj_rhs: Conj,
1120 alpha: &T,
1121 par: Par,
1122) {
1123 let mut acc = dst.as_dyn_mut();
1124 let mut lhs = lhs.as_dyn();
1125 let mut rhs = rhs.as_dyn();
1126
1127 let mut acc_structure = dst_structure;
1128 let mut lhs_structure = lhs_structure;
1129 let mut rhs_structure = rhs_structure;
1130
1131 let mut conj_lhs = conj_lhs;
1132 let mut conj_rhs = conj_rhs;
1133
1134 if rhs_structure.is_lower() {
1136 false
1138 } else if rhs_structure.is_upper() {
1139 acc = acc.reverse_rows_and_cols_mut();
1141 lhs = lhs.reverse_rows_and_cols();
1142 rhs = rhs.reverse_rows_and_cols();
1143 acc_structure = acc_structure.transpose();
1144 lhs_structure = lhs_structure.transpose();
1145 rhs_structure = rhs_structure.transpose();
1146 false
1147 } else if lhs_structure.is_lower() {
1148 acc = acc.reverse_rows_and_cols_mut().transpose_mut();
1150 (lhs, rhs) = (rhs.reverse_rows_and_cols().transpose(), lhs.reverse_rows_and_cols().transpose());
1151 (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
1152 (lhs_structure, rhs_structure) = (rhs_structure, lhs_structure);
1153 true
1154 } else if lhs_structure.is_upper() {
1155 acc_structure = acc_structure.transpose();
1157 acc = acc.transpose_mut();
1158 (lhs, rhs) = (rhs.transpose(), lhs.transpose());
1159 (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
1160 (lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose());
1161 true
1162 } else {
1163 false
1165 };
1166
1167 make_guard!(M);
1168 make_guard!(N);
1169 make_guard!(K);
1170 let M = acc.nrows().bind(M);
1171 let N = acc.ncols().bind(N);
1172 let K = lhs.ncols().bind(K);
1173
1174 let clear_upper = |acc: MatMut<'_, T>, skip_diag: bool| match &beta {
1175 Accum::Add => {},
1176
1177 Accum::Replace => zip!(acc).for_each_triangular_upper(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(acc)| *acc = zero()),
1178 };
1179
1180 let skip_diag = matches!(
1181 acc_structure,
1182 BlockStructure::StrictTriangularLower
1183 | BlockStructure::StrictTriangularUpper
1184 | BlockStructure::UnitTriangularLower
1185 | BlockStructure::UnitTriangularUpper
1186 );
1187 let lhs_diag = lhs_structure.diag_kind();
1188 let rhs_diag = rhs_structure.diag_kind();
1189
1190 if acc_structure.is_dense() {
1191 if lhs_structure.is_dense() && rhs_structure.is_dense() {
1192 super::matmul_with_conj(acc, beta, lhs, conj_lhs, rhs, conj_rhs, alpha.clone(), par);
1193 } else {
1194 debug_assert!(rhs_structure.is_lower());
1195
1196 if lhs_structure.is_dense() {
1197 mat_x_lower_impl_unchecked(
1198 acc.as_shape_mut(M, N),
1199 beta,
1200 lhs.as_shape(M, N),
1201 rhs.as_shape(N, N),
1202 rhs_diag,
1203 alpha,
1204 conj_lhs,
1205 conj_rhs,
1206 par,
1207 )
1208 } else if lhs_structure.is_lower() {
1209 clear_upper(acc.rb_mut(), true);
1210 lower_x_lower_into_lower_impl_unchecked(
1211 acc.as_shape_mut(N, N),
1212 beta,
1213 false,
1214 lhs.as_shape(N, N),
1215 lhs_diag,
1216 rhs.as_shape(N, N),
1217 rhs_diag,
1218 alpha,
1219 conj_lhs,
1220 conj_rhs,
1221 par,
1222 );
1223 } else {
1224 debug_assert!(lhs_structure.is_upper());
1225 upper_x_lower_impl_unchecked(
1226 acc.as_shape_mut(N, N),
1227 beta,
1228 lhs.as_shape(N, N),
1229 lhs_diag,
1230 rhs.as_shape(N, N),
1231 rhs_diag,
1232 alpha,
1233 conj_lhs,
1234 conj_rhs,
1235 par,
1236 )
1237 }
1238 }
1239 } else if acc_structure.is_lower() {
1240 if lhs_structure.is_dense() && rhs_structure.is_dense() {
1241 mat_x_mat_into_lower_impl_unchecked(
1242 acc.as_shape_mut(N, N),
1243 beta,
1244 skip_diag,
1245 lhs.as_shape(N, K),
1246 rhs.as_shape(K, N),
1247 alpha,
1248 conj_lhs,
1249 conj_rhs,
1250 par,
1251 )
1252 } else {
1253 debug_assert!(rhs_structure.is_lower());
1254 if lhs_structure.is_dense() {
1255 mat_x_lower_into_lower_impl_unchecked(
1256 acc.as_shape_mut(N, N),
1257 beta,
1258 skip_diag,
1259 lhs.as_shape(N, N),
1260 rhs.as_shape(N, N),
1261 rhs_diag,
1262 alpha,
1263 conj_lhs,
1264 conj_rhs,
1265 par,
1266 );
1267 } else if lhs_structure.is_lower() {
1268 lower_x_lower_into_lower_impl_unchecked(
1269 acc.as_shape_mut(N, N),
1270 beta,
1271 skip_diag,
1272 lhs.as_shape(N, N),
1273 lhs_diag,
1274 rhs.as_shape(N, N),
1275 rhs_diag,
1276 alpha,
1277 conj_lhs,
1278 conj_rhs,
1279 par,
1280 )
1281 } else {
1282 upper_x_lower_into_lower_impl_unchecked(
1283 acc.as_shape_mut(N, N),
1284 beta,
1285 skip_diag,
1286 lhs.as_shape(N, N),
1287 lhs_diag,
1288 rhs.as_shape(N, N),
1289 rhs_diag,
1290 alpha,
1291 conj_lhs,
1292 conj_rhs,
1293 par,
1294 )
1295 }
1296 }
1297 } else if lhs_structure.is_dense() && rhs_structure.is_dense() {
1298 mat_x_mat_into_lower_impl_unchecked(
1299 acc.as_shape_mut(N, N).transpose_mut(),
1300 beta,
1301 skip_diag,
1302 rhs.transpose().as_shape(N, K),
1303 lhs.transpose().as_shape(K, N),
1304 alpha,
1305 conj_rhs,
1306 conj_lhs,
1307 par,
1308 )
1309 } else {
1310 debug_assert!(rhs_structure.is_lower());
1311 if lhs_structure.is_dense() {
1312 upper_x_lower_into_lower_impl_unchecked(
1314 acc.as_shape_mut(N, N).transpose_mut(),
1315 beta,
1316 skip_diag,
1317 rhs.transpose().as_shape(N, N),
1318 rhs_diag,
1319 lhs.transpose().as_shape(N, N),
1320 lhs_diag,
1321 alpha,
1322 conj_rhs,
1323 conj_lhs,
1324 par,
1325 )
1326 } else if lhs_structure.is_lower() {
1327 if !skip_diag {
1328 match beta {
1329 Accum::Add => {
1330 for j in 0..N.unbound() {
1331 acc[(j, j)] = acc[(j, j)] + *alpha * lhs[(j, j)] * rhs[(j, j)];
1332 }
1333 },
1334 Accum::Replace => {
1335 for j in 0..N.unbound() {
1336 acc[(j, j)] = *alpha * lhs[(j, j)] * rhs[(j, j)];
1337 }
1338 },
1339 }
1340 }
1341 clear_upper(acc.rb_mut(), true);
1342 } else {
1343 debug_assert!(lhs_structure.is_upper());
1344 upper_x_lower_into_lower_impl_unchecked(
1345 acc.as_shape_mut(N, N).transpose_mut(),
1346 beta,
1347 skip_diag,
1348 rhs.transpose().as_shape(N, N),
1349 rhs_diag,
1350 lhs.transpose().as_shape(N, N),
1351 lhs_diag,
1352 alpha,
1353 conj_rhs,
1354 conj_lhs,
1355 par,
1356 )
1357 }
1358 }
1359}